package onnxruntime

import (
	"errors"
	"fmt"
	"reflect"
	"unsafe"
)

// #cgo CFLAGS: -O2 -g
//
// #include "onnxruntime_wrapper.h"
import "C"

// FloatData is the set of floating-point element types (including named types
// whose underlying type is float32 or float64) accepted by TensorData.
type FloatData interface {
	~float64 | ~float32
}

// IntData is the set of fixed-width integer element types (including named
// types whose underlying type is one of them) accepted by TensorData.
type IntData interface {
	~uint8 | ~int8 | ~uint16 | ~int16 | ~uint32 | ~int32 | ~uint64 | ~int64
}

// TensorData is the type constraint used by the generic Tensor type. It admits
// every element type permitted by either IntData or FloatData.
type TensorData interface {
	IntData | FloatData
}

// GetTensorElementDataType returns the ONNX enum value that identifies the
// TensorData type T, or ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED if T's
// underlying kind is not one of the recognized numeric kinds.
func GetTensorElementDataType[T TensorData]() C.ONNXTensorElementDataType {
	// A type switch cannot see through named types such as
	// `type MyFloat float32`, so inspect T's underlying kind with reflection.
	// Using the element type of *T avoids boxing a zero value of T.
	elemKind := reflect.TypeOf((*T)(nil)).Elem().Kind()
	switch elemKind {
	case reflect.Int8:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8
	case reflect.Int16:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16
	case reflect.Int32:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
	case reflect.Int64:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
	case reflect.Uint8:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8
	case reflect.Uint16:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16
	case reflect.Uint32:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32
	case reflect.Uint64:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64
	case reflect.Float32:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
	case reflect.Float64:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE
	default:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
	}
}

// TensorElementDataType wraps the ONNXTensorElementDataType enum in C.
type TensorElementDataType int

// The TensorElementDataType* constants mirror the C API's
// ONNX_TENSOR_ELEMENT_DATA_TYPE_* enum values, one constant per enum member.
// NOTE(review): these constants are declared without an explicit
// TensorElementDataType type, so e.g. fmt.Println(TensorElementDataTypeFloat)
// will not go through TensorElementDataType.String unless the value is first
// converted to TensorElementDataType. Confirm no caller depends on the current
// type before changing it.
const (
	TensorElementDataTypeUndefined = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
	TensorElementDataTypeFloat     = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
	TensorElementDataTypeUint8     = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8
	TensorElementDataTypeInt8      = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8
	TensorElementDataTypeUint16    = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16
	TensorElementDataTypeInt16     = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16
	TensorElementDataTypeInt32     = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
	TensorElementDataTypeInt64     = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
	TensorElementDataTypeString    = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING
	TensorElementDataTypeBool      = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
	TensorElementDataTypeFloat16   = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
	TensorElementDataTypeDouble    = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE
	TensorElementDataTypeUint32    = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32
	TensorElementDataTypeUint64    = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64

	// Not supported by onnxruntime (as of onnxruntime version 1.16.1)
	TensorElementDataTypeComplex64 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
	// Not supported by onnxruntime (as of onnxruntime version 1.16.1)
	TensorElementDataTypeComplex128 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128

	// Non-IEEE floating-point format based on IEEE754 single-precision
	TensorElementDataTypeBFloat16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16

	// 8-bit float types, introduced in onnx 1.14.  See
	// https://onnx.ai/onnx/technical/float8.html
	TensorElementDataTypeFloat8E4M3FN   = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN
	TensorElementDataTypeFloat8E4M3FNUZ = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ
	TensorElementDataTypeFloat8E5M2     = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2
	TensorElementDataTypeFloat8E5M2FNUZ = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ
)

// tensorElementDataTypeNames maps every known TensorElementDataType to the
// name of the matching C enum constant. It is consulted by String.
var tensorElementDataTypeNames = map[TensorElementDataType]string{
	TensorElementDataTypeUndefined:      "ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED",
	TensorElementDataTypeFloat:          "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT",
	TensorElementDataTypeUint8:          "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8",
	TensorElementDataTypeInt8:           "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8",
	TensorElementDataTypeUint16:         "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16",
	TensorElementDataTypeInt16:          "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16",
	TensorElementDataTypeInt32:          "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32",
	TensorElementDataTypeInt64:          "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64",
	TensorElementDataTypeString:         "ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING",
	TensorElementDataTypeBool:           "ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL",
	TensorElementDataTypeFloat16:        "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16",
	TensorElementDataTypeDouble:         "ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE",
	TensorElementDataTypeUint32:         "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32",
	TensorElementDataTypeUint64:         "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64",
	TensorElementDataTypeComplex64:      "ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64",
	TensorElementDataTypeComplex128:     "ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128",
	TensorElementDataTypeBFloat16:       "ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16",
	TensorElementDataTypeFloat8E4M3FN:   "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN",
	TensorElementDataTypeFloat8E4M3FNUZ: "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ",
	TensorElementDataTypeFloat8E5M2:     "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2",
	TensorElementDataTypeFloat8E5M2FNUZ: "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ",
}

// String returns the C enum constant name for t. For a value not in the table
// it returns a message that includes t's numeric value.
func (t TensorElementDataType) String() string {
	if name, known := tensorElementDataTypeNames[t]; known {
		return name
	}
	return fmt.Sprintf("Unknown tensor element data type: %d", int(t))
}

// CustomDataTensor satisfies the ArbitraryTensor interface, but is intended to
// allow users to provide tensors of types that may not be supported by the
// generic typed Tensor[T] struct. Instead, CustomDataTensors are backed by a
// slice of bytes, using a user-provided shape and type from the
// ONNXTensorElementDataType enum. Create one with NewCustomDataTensor and free
// it with Destroy.
type CustomDataTensor struct {
	// The same byte slice passed to NewCustomDataTensor (not a copy); set to
	// nil by Destroy.
	data     []byte
	// Element type handed to the C API at creation; reset to
	// ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED by Destroy.
	dataType C.ONNXTensorElementDataType
	// A clone of the shape passed to NewCustomDataTensor.
	shape    Shape
	// Handle produced by CreateOrtTensorWithShape; released and set to nil by
	// Destroy.
	ortValue *C.OrtValue
}

// NewCustomDataTensor creates and returns a new CustomDataTensor using the
// given bytes as the underlying data slice. Apart from validating the shape
// and ensuring that the provided data slice is non-empty, this function mostly
// delegates validation of the provided data to the C onnxruntime library. For
// example, it is the caller's responsibility to ensure that the provided
// dataType and data slice are valid and correctly sized for the specified
// shape. If this returns successfully, the caller must call the returned
// tensor's Destroy() function to free it when no longer in use.
//
// The tensor keeps a reference to data itself, not a copy (GetData returns the
// same slice). NOTE(review): this presumably means ORT reads the Go memory
// directly; confirm CreateOrtTensorWithShape does not copy before relying on
// the caller being free to reuse data.
//
// Returns ErrorNotInitialized if the library has not been initialized. Errors
// reported by the ORT API are wrapped with %w, so callers can inspect them with
// errors.Is / errors.As.
func NewCustomDataTensor(s Shape, data []byte,
	dataType TensorElementDataType) (*CustomDataTensor, error) {
	if !IsInitialized() {
		return nil, ErrorNotInitialized
	}
	if e := s.Validate(); e != nil {
		return nil, fmt.Errorf("invalid tensor shape: %w", e)
	}
	// &s[0] below would panic on an empty shape; guard explicitly rather than
	// relying on Validate (defined elsewhere) to reject it.
	if len(s) == 0 {
		return nil, errors.New("a CustomDataTensor requires a shape with at least one dimension")
	}
	if len(data) == 0 {
		return nil, errors.New("a CustomDataTensor requires at least one byte of data")
	}
	dt := C.ONNXTensorElementDataType(dataType)
	var ortValue *C.OrtValue

	status := C.CreateOrtTensorWithShape(unsafe.Pointer(&data[0]),
		C.size_t(len(data)), (*C.int64_t)(unsafe.Pointer(&s[0])),
		C.int64_t(len(s)), ortMemoryInfo, dt, &ortValue)
	if status != nil {
		// %w (not %s) preserves the underlying error for errors.Is/As; the
		// formatted message text is unchanged.
		return nil, fmt.Errorf("ORT API error creating tensor: %w",
			statusToError(status))
	}
	return &CustomDataTensor{
		data:     data,
		dataType: dt,
		shape:    s.Clone(),
		ortValue: ortValue,
	}, nil
}

// Destroy releases the tensor's underlying OrtValue and clears all of its
// fields. It is safe to call more than once: after the first call ortValue is
// nil, so the C release is skipped rather than being handed a NULL pointer.
// The data slice itself is not freed or modified; the tensor just drops its
// reference to it. Always returns nil.
func (t *CustomDataTensor) Destroy() error {
	if t.ortValue != nil {
		C.ReleaseOrtValue(t.ortValue)
		t.ortValue = nil
	}
	t.data = nil
	t.shape = nil
	t.dataType = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
	return nil
}

// DataType reports the C element-type enum value this tensor was created with.
// After Destroy it reports ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED.
func (t *CustomDataTensor) DataType() C.ONNXTensorElementDataType {
	elemType := t.dataType
	return elemType
}

// GetShape returns a clone of the tensor's shape (via Shape.Clone) rather than
// the internally stored value.
func (t *CustomDataTensor) GetShape() Shape {
	shapeCopy := t.shape.Clone()
	return shapeCopy
}

// GetInternals returns a freshly allocated TensorInternalData holding the
// tensor's OrtValue handle. The handle is nil once Destroy has run.
func (t *CustomDataTensor) GetInternals() *TensorInternalData {
	internals := TensorInternalData{OrtValue: t.ortValue}
	return &internals
}

// ZeroContents sets every byte in the tensor's data slice to 0. The clearing is
// done in Go, without a cgo memset call. As a result it is a no-op, rather than
// an index-out-of-range panic from &t.data[0], when the slice is empty or nil
// (as it is after Destroy).
func (t *CustomDataTensor) ZeroContents() {
	// The compiler recognizes this loop and lowers it to a memclr.
	for i := range t.data {
		t.data[i] = 0
	}
}

// GetData returns the byte slice that was passed to NewCustomDataTensor. It is
// the same slice, not a copy, so writes through it change the tensor's data.
// Returns nil after Destroy.
func (t *CustomDataTensor) GetData() []byte {
	backing := t.data
	return backing
}
